import re
from typing import Any, Dict, Optional, Tuple
import pandas as pd
from omegaconf import DictConfig
import numpy as np


def apply_ordinal_encoding(
    dataframe: pd.DataFrame, indexes: Dict[str, list]
) -> pd.DataFrame:
    """
    Replace the values of selected columns with their position in a category list.

    The dataframe is modified in place and also returned, so callers can
    reassign the result.

    Parameters:
    - dataframe: The pandas DataFrame to encode.
    - indexes: A dictionary mapping column names to the ordered list of values
      (categories) of that column. A value is encoded as its position in the
      list; values that do not appear in the list become missing.

    Returns:
    - The encoded dataframe (the same object that was passed in).
    """
    for column, categories in indexes.items():
        # Build the value -> position lookup once per column instead of scanning
        # the list twice (``in`` and then ``.index``) for every row. ``setdefault``
        # keeps the first position of a repeated category, like ``list.index``.
        positions: Dict[Any, int] = {}
        for position, category in enumerate(categories):
            positions.setdefault(category, position)
        # ``dict.get`` returns None for unknown values, as the original mapping did.
        dataframe[column] = dataframe[column].map(positions.get)
    return dataframe

def get_avg_std(
    data_df: pd.DataFrame,
    exclude_seed: bool = True,
    exclude_refutation_if_nan: bool = True,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Compute the per-column mean and standard deviation of a DataFrame.

    The input DataFrame is not modified.

    Parameters:
    - data_df (pd.DataFrame): The DataFrame to calculate the statistics for.
    - exclude_seed (bool): Whether to leave the "seed" column out of the
      statistics (when present). Defaults to True.
    - exclude_refutation_if_nan (bool): Whether to leave the "Refutation Time"
      and "Refutation" columns out of the statistics (when present).
      NOTE: despite the name, these columns are excluded whenever present;
      their content is not checked for NaN. Defaults to True.

    Returns:
    - tuple[pd.DataFrame, pd.DataFrame]: Two single-row DataFrames (index [0]).
      The first has a "<col>_average" column per remaining input column, the
      second a "<col>_std" column per remaining input column.
    """
    columns_to_drop = []
    if exclude_seed and "seed" in data_df.columns:
        columns_to_drop.append("seed")
    if exclude_refutation_if_nan:
        columns_to_drop.extend(
            col for col in ("Refutation Time", "Refutation") if col in data_df.columns
        )

    # Non-inplace drop: the caller's DataFrame must keep all of its columns.
    stats_df = data_df.drop(columns=columns_to_drop)

    avg_df = pd.DataFrame(
        {f"{col}_average": stats_df[col].mean() for col in stats_df.columns},
        index=[0],
    )
    std_df = pd.DataFrame(
        {f"{col}_std": stats_df[col].std() for col in stats_df.columns},
        index=[0],
    )
    return avg_df, std_df


def only_obs_data(
    data: pd.DataFrame, int_table: pd.DataFrame
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Keep only the purely observational rows of the data and of its table.

    A row is treated as interventional when any entry of the matching row of
    ``int_table`` equals 1; such rows are removed from both frames.

    Args:
        data (pd.DataFrame): Observational and interventional samples.
        int_table (pd.DataFrame): Interventional table, row-aligned with ``data``.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: The filtered ``data`` and the filtered
        ``int_table``; both keep their original index labels.
    """
    is_interventional = int_table.eq(1).any(axis=1)
    keep = ~is_interventional
    return data[keep], int_table[keep]


def extract_empirical_distribution(
    samples: pd.DataFrame, outcome: str, evidence: Optional[dict[str, float]] = None
) -> pd.Series:
    """Return the empirical distribution of ``outcome``, optionally given evidence.

    Args:
        samples: DataFrame of samples.
        outcome: Name of the column whose distribution is computed.
        evidence: Optional mapping ``column -> value``. Only rows whose columns
            equal every given value are used. ``None`` or an empty mapping
            uses all rows.

    Returns:
        pd.Series: Relative frequency of each observed value of ``outcome``,
        indexed by that value and sorted by index. The Series is empty when no
        row matches the evidence.
    """
    if evidence:
        # AND all evidence conditions into a single mask, then filter once.
        mask = pd.Series(True, index=samples.index)
        for col, ev_value in evidence.items():
            mask &= samples[col] == ev_value
        samples = samples[mask]

    return samples[outcome].value_counts(normalize=True).sort_index()


def quantize_data(
    df: pd.DataFrame, quant_steps: int, columns_to_discretize: list[str]
) -> Tuple[pd.DataFrame, Dict[str, np.ndarray]]:
    """Quantize selected columns into ``quant_steps`` equal-width bins.

    Each value is replaced by the left edge of the ``pd.cut`` bin it falls in.
    Missing values stay missing. ``df`` is modified in place and also returned.

    Args:
        df: DataFrame to quantize.
        quant_steps: Number of equal-width bins per column.
        columns_to_discretize: Names of the columns to quantize.

    Returns:
        Tuple[pd.DataFrame, Dict[str, np.ndarray]]: The quantized DataFrame and,
        for every quantized column, the bin edges returned by ``pd.cut``. Reuse
        these edges to quantize other values the same way.
    """
    bin_dict = {}
    for column in columns_to_discretize:
        codes, bins = pd.cut(df[column], quant_steps, labels=False, retbins=True)
        # With missing values pd.cut returns float codes containing NaN, which
        # cannot index ``bins``; convert valid codes to int and keep NaN as NaN.
        df[column] = codes.map(
            lambda code: bins[int(code)] if pd.notna(code) else np.nan
        )
        bin_dict[column] = bins

    return df, bin_dict


def normalize_data(
    df: pd.DataFrame,
    max_values: Optional[list[float]] = None,
    min_values: Optional[list[float]] = None,
) -> tuple[pd.DataFrame, Any, Any]:
    """
    Linearly rescale each column of ``df`` to [-1, 1] using per-column bounds.

    Column ``i`` is mapped with
    ``2 * (x - min_values[i]) / (max_values[i] - min_values[i]) - 1``.
    Bounds are matched to columns by position.

    Parameters:
    df (pd.DataFrame): The input DataFrame. It is not modified.
    max_values: Per-column maxima; computed from ``df`` when None.
    min_values: Per-column minima; computed from ``df`` when None.

    Returns:
    tuple: ``(normalized_df, max_values, min_values)``. The bounds are the ones
    actually used, so they can be passed back in to normalize other data
    consistently.

    Note: in a column whose bounds coincide (zero range), the 0/0 results are
    NaN and every NaN of that column is replaced by 1. NaNs in other columns
    are preserved.
    """
    # Work on a copy so the caller's DataFrame is not modified column by column.
    df = df.copy()

    if max_values is None:
        max_values = df.max().values
    if min_values is None:
        min_values = df.min().values

    for i, column in enumerate(df.columns):
        value_range = max_values[i] - min_values[i]
        df[column] = 2 * (df[column] - min_values[i]) / value_range - 1
        if value_range == 0:
            # Constant column: 0/0 produced NaN; map it to 1. The fill is limited
            # to such columns so genuine missing values elsewhere are not
            # silently turned into the column maximum.
            df[column] = df[column].fillna(1)

    return df, max_values, min_values


def conf2dict(conf: DictConfig) -> Dict[str, Any]:
    """Return a shallow, plain ``dict`` copy of a Hydra/OmegaConf config.

    Only the top level is converted. Each value is stored exactly as
    ``conf[key]`` returns it, so nested values are not converted.

    Args:
        conf (DictConfig): Hydra configuration to copy.

    Returns:
        Dict[str, Any]: A new dictionary with the same keys and values.
    """
    return {key: conf[key] for key in conf}

def insert_in_bucket(value: float, buckets: np.ndarray, tol: float = 1e-4) -> float:
    """Snap ``value`` to the nearest element of ``buckets``.

    Used to absorb small numerical errors, so that a computed value coincides
    exactly with one of the values in the data.

    Args:
        value: Value to snap.
        buckets: Candidate values. NaN entries are ignored.
        tol: Maximum accepted distance to the nearest bucket. Defaults to 1e-4.

    Returns:
        The element of ``buckets`` closest to ``value``; ties go to the first.

    Raises:
        ValueError: If ``buckets`` is empty or all-NaN, if ``value`` is NaN, or
            if the nearest bucket is farther than ``tol``.
    """
    candidates = np.asarray(buckets, dtype=float)
    if candidates.size == 0:
        raise ValueError("Cannot insert a value into an empty set of buckets.")
    if np.isnan(value):
        # A NaN would otherwise silently land in an arbitrary bucket.
        raise ValueError("Cannot insert a NaN value into a bucket.")

    distances = np.abs(candidates - value)
    if np.all(np.isnan(distances)):
        raise ValueError("Cannot insert a value into buckets that are all NaN.")

    # nanargmin skips NaN buckets and returns the first minimum.
    nearest = int(np.nanargmin(distances))
    if distances[nearest] > tol:
        raise ValueError("Value does not fit into any bucket. Data normalization is bugged.")

    return buckets[nearest]


def _parse_do_expression(expression: str) -> Tuple[str, float]:
    """Split an intervention string such as ``"do(X=1)"`` into ``("X", 1.0)``.

    Whitespace around the variable name and the value is ignored. The value may
    be signed, fractional or in exponent notation (e.g. ``"do(X = -0.5)"``).

    Raises:
        ValueError: If ``expression`` contains no ``do(<variable>=<number>)`` term.
    """
    match = re.search(
        r"do\(\s*([^=]+?)\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)\s*\)",
        expression,
    )
    if match is None:
        raise ValueError(
            f"Cannot parse intervention {expression!r}: expected 'do(<variable>=<number>)'."
        )
    return match.group(1), float(match.group(2))


def _normalize_scalar(value: float, max_value: float, min_value: float) -> float:
    """Rescale one value with the [-1, 1] mapping that ``normalize_data`` applies.

    ``normalize_data`` sets a constant column (zero range, 0/0 = NaN) to 1, so a
    value equal to that constant is mapped to 1.0 as well. Any other value maps
    to +/-inf, which is what dividing by a zero range gives for the data.
    """
    value_range = max_value - min_value
    if value_range == 0:
        if value == min_value:
            return 1.0
        return float("inf") if value > min_value else float("-inf")
    return float(2 * (value - min_value) / value_range - 1)


def _quantize_scalar(value: float, bins: np.ndarray) -> float:
    """Map one value to the bin value that ``quantize_data`` gives the data.

    ``quantize_data`` replaces each value by the left edge of its ``pd.cut`` bin.
    Cutting the single value with the same edges yields the same representative.

    Raises:
        ValueError: If the value falls outside the bin edges.
    """
    code = pd.cut([value], bins, labels=False)[0]
    if pd.isna(code):
        raise ValueError(
            f"Value {value} falls outside the quantization range ({bins[0]}, {bins[-1]}]."
        )
    return bins[int(code)]


def _shuffle_together(
    train_data: pd.DataFrame, int_table: Optional[pd.DataFrame]
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
    """Shuffle ``train_data`` (and ``int_table``, if given) with one shared permutation."""
    permutation = np.random.permutation(train_data.index)
    train_data = train_data.loc[permutation].reset_index(drop=True)
    if int_table is not None:
        int_table = int_table.loc[permutation].reset_index(drop=True)
    return train_data, int_table


def _balance_rows(
    train_data: pd.DataFrame,
    int_table: Optional[pd.DataFrame],
    balance_variable: str,
    dataset_size: int,
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
    """Keep the same number of rows for every value of ``balance_variable``.

    With ``dataset_size == -1``, every group is cut to the size of the smallest
    group. Otherwise every group keeps at most ``dataset_size / n_groups`` rows.
    ``int_table`` rows are selected with the same mask, which assumes both frames
    share the same index.
    """
    column = train_data[balance_variable]
    uniques = column.unique()
    if len(uniques) == 0:
        return train_data, int_table

    if dataset_size == -1:
        # Without an explicit budget, int(-1 / n_groups) would keep 0 rows (or
        # drop the last row); use the smallest group size instead.
        counts = column.value_counts()
        group_size = int(counts.min()) if len(counts) else 0
    else:
        group_size = int(dataset_size / len(uniques))

    train_parts, table_parts = [], []
    for value in uniques:
        mask = column == value
        train_parts.append(train_data[mask][:group_size])
        if int_table is not None:
            table_parts.append(int_table[mask][:group_size])

    train_data = pd.concat(train_parts).reset_index(drop=True)
    if int_table is not None:
        int_table = pd.concat(table_parts).reset_index(drop=True)
    return train_data, int_table


def _drop_extra_columns(df: pd.DataFrame, columns: pd.Index) -> pd.DataFrame:
    """Drop the columns of ``df`` that are not in ``columns`` and report mismatches."""
    extra = [item for item in df.columns if item not in columns]
    missing = [item for item in columns if item not in df.columns]
    if extra:
        df = df.drop(columns=extra)
        print(f"Columns {extra} were dropped from the dataframe.")
    if missing:
        print(f"Columns {missing} are missing from the ground-truth dataframe.")
    return df


def preprocess_data(
    train_data: pd.DataFrame,
    ground_truth: list[Dict[str, pd.DataFrame]],
    treatment_list: list[Dict[str, Any]],
    evidence: Optional[Dict[str, Any]] = None,
    quantize: bool = False,
    quant_steps: int = 20,
    dataset_size: int = -1,
    use_interventional_data: bool = False,
    int_table: Optional[pd.DataFrame] = None,
    balance_variable: Optional[str] = None,
) -> Tuple[
    pd.DataFrame,
    Optional[pd.DataFrame],
    list[Dict[str, pd.DataFrame]],
    list[Dict[str, Any]],
    Dict[str, Any],
]:
    """
    Prepare training data, ground truth, treatments and evidence for effect estimation.

    Steps, in order:
      1. Unless ``use_interventional_data`` is True, drop the rows flagged as
         interventional in ``int_table`` (any entry equal to 1).
      2. Shuffle. Then either keep the first ``dataset_size`` rows (``-1`` keeps
         all), or, when ``balance_variable`` is given, keep the same number of
         rows for each of its values. Shuffle again.
      3. Drop ground-truth columns that are not present in the training data.
      4. Treat columns with fewer than ``quant_steps`` distinct values as
         discrete. Ordinal-encode those with more than two distinct values.
      5. Rescale every training column to [-1, 1]. Apply the same per-column
         rescaling (matched by column name) to the ground truth, the
         treatment/control values and the evidence.
      6. If ``quantize`` is True, quantize the continuous training columns and
         map continuous treatment/control and evidence values to the same
         bins. The ground truth is never quantized.

    Parameters:
        train_data: The training data to preprocess.
        ground_truth: One dict per experiment with "treated_data" and
            "control_data" DataFrames. The dicts are updated in place.
        treatment_list: One mapping per experiment with "treatment" and
            "control" strings of the form "do(<variable>=<number>)". Returned as
            plain dicts extended with "treatment_var", "control_var",
            "treatment_value" and "control_value".
        evidence: Mapping from column name to observed value. A processed copy
            is returned; the argument itself is not modified. Defaults to no
            evidence.
        quantize: Whether to quantize continuous columns. Defaults to False.
        quant_steps: Number of quantization bins. Also the threshold on distinct
            values below which a column is discrete. Defaults to 20.
        dataset_size: Maximum number of rows; -1 keeps all rows (with balancing,
            -1 uses the smallest group size). Defaults to -1.
        use_interventional_data: Keep interventional rows. Defaults to False.
        int_table: Interventional table aligned row-by-row with ``train_data``.
            Required when ``use_interventional_data`` is False.
        balance_variable: Optional column whose values should be equally
            represented in the training data.

    Returns:
        tuple: ``(train_data, int_table, ground_truth, treatment_list, evidence)``
        after preprocessing.

    Raises:
        ValueError: If ``int_table`` is missing or misaligned while filtering,
            if an intervention string cannot be parsed, if treatment and control
            target different variables, if a treatment/evidence variable is not
            a training column, or if a value cannot be bucketed.
    """
    if not use_interventional_data:
        if int_table is None:
            raise ValueError(
                "int_table is required to remove interventional rows "
                "when use_interventional_data is False."
            )
        if train_data.shape[0] != int_table.shape[0]:
            raise ValueError("Train data and interventional table have a different number of rows.")
        train_data, int_table = only_obs_data(train_data, int_table)

    # Shuffle before subsampling so the kept rows are a random subset.
    train_data, int_table = _shuffle_together(train_data, int_table)

    # Some models cannot deal with big datasets, so we cap the number of rows.
    if balance_variable is None:
        if dataset_size != -1:
            train_data = train_data[:dataset_size]
            if int_table is not None:
                int_table = int_table[:dataset_size]
    else:
        # Balancing helps particularly learning-based causal models.
        train_data, int_table = _balance_rows(
            train_data, int_table, balance_variable, dataset_size
        )

    # Balancing concatenates groups in order, so shuffle again.
    train_data, int_table = _shuffle_together(train_data, int_table)

    # Hydra configs are not modifiable: work on plain-dict copies.
    treatment_list = [conf2dict(entry) for entry in treatment_list]
    evidence = {} if evidence is None else dict(evidence)

    # Each gt_dict holds the treated/control populations of one experiment.
    columns = train_data.columns
    for gt_dict in ground_truth:
        gt_dict["treated_data"] = _drop_extra_columns(gt_dict["treated_data"], columns)
        gt_dict["control_data"] = _drop_extra_columns(gt_dict["control_data"], columns)

    n_unique = train_data.nunique()
    discrete_columns = train_data.columns[n_unique < quant_steps]
    continuous_columns = train_data.columns[n_unique >= quant_steps]

    # Ordinal encoding puts discrete values in equally spaced buckets after
    # normalization. Binary columns are left as they are.
    indexes = {
        column: sorted(train_data[column].unique())
        for column in discrete_columns
        if n_unique[column] > 2
    }
    train_data = apply_ordinal_encoding(train_data, indexes)

    train_data, max_values, min_values = normalize_data(train_data)
    # Look bounds up by column name so that ground-truth frames whose columns
    # are reordered or partially missing are still rescaled correctly.
    max_by_col = dict(zip(train_data.columns, max_values))
    min_by_col = dict(zip(train_data.columns, min_values))

    # NOTE: The ground truth should not be quantized.
    for gt_dict in ground_truth:
        for key in ("treated_data", "control_data"):
            gt_df = gt_dict[key]
            gt_indexes = {c: idx for c, idx in indexes.items() if c in gt_df.columns}
            gt_df = apply_ordinal_encoding(gt_df, gt_indexes)
            gt_dict[key], _, _ = normalize_data(
                gt_df,
                [max_by_col[c] for c in gt_df.columns],
                [min_by_col[c] for c in gt_df.columns],
            )

    for tr in treatment_list:
        tr_name, tr_value = _parse_do_expression(tr["treatment"])
        ctrl_name, ctrl_value = _parse_do_expression(tr["control"])

        if tr_name != ctrl_name:
            raise ValueError(
                f"The treatment and control variables are different: {tr_name} and {ctrl_name}."
            )
        if tr_name not in max_by_col:
            raise ValueError(f"The treatment variable {tr_name} is not a column of the training data.")

        if tr_name in indexes:
            # Ordinal-encoded column: use the value's position, as for the data.
            tr_value = indexes[tr_name].index(tr_value)
            ctrl_value = indexes[tr_name].index(ctrl_value)

        tr_value = _normalize_scalar(tr_value, max_by_col[tr_name], min_by_col[tr_name])
        ctrl_value = _normalize_scalar(ctrl_value, max_by_col[tr_name], min_by_col[tr_name])

        if tr_name in discrete_columns:
            # Normalization may introduce a tiny numerical error, yielding values
            # that do not coincide with the variable's discrete domain; snap
            # them to the values actually present in the data.
            buckets = train_data[tr_name].unique()
            tr_value = insert_in_bucket(tr_value, buckets)
            ctrl_value = insert_in_bucket(ctrl_value, buckets)

        tr.update(
            {
                "treatment_var": tr_name,
                "control_var": ctrl_name,
                "treatment_value": tr_value,
                "control_value": ctrl_value,
            }
        )

    for key in list(evidence):
        if key not in max_by_col:
            raise ValueError(f"The evidence variable {key} is not a column of the training data.")
        value = evidence[key]
        if key in indexes:
            value = indexes[key].index(value)
        evidence[key] = _normalize_scalar(value, max_by_col[key], min_by_col[key])

    # Quantize continuous columns, and map treatment/evidence values of those
    # columns to the same bins so that they are present in the data.
    if quantize:
        train_data, bin_dict = quantize_data(train_data, quant_steps, continuous_columns)

        for tr in treatment_list:
            var = tr["treatment_var"]
            if var in continuous_columns:
                tr["treatment_value"] = _quantize_scalar(tr["treatment_value"], bin_dict[var])
                tr["control_value"] = _quantize_scalar(tr["control_value"], bin_dict[var])

        for key in evidence:
            if key in continuous_columns:
                evidence[key] = _quantize_scalar(evidence[key], bin_dict[key])

    return train_data, int_table, ground_truth, treatment_list, evidence


def get_ground_truth_distributions(
    ground_truth: list[Dict[str, pd.DataFrame]],
    treatment_list: list[Dict[str, Any]],
    outcome: str,
    evidence: Optional[dict[str, Any]] = None,
) -> list[Dict[str, Any]]:
    """Attach ground-truth distributions of ``outcome`` to every treatment.

    For the k-th treatment, ``ground_truth[k]["treated_data"]`` and
    ``ground_truth[k]["control_data"]`` are used to compute the distribution of
    ``outcome`` both over all rows (interventional distribution) and over the
    rows matching ``evidence`` (conditional interventional distribution).

    Args:
        ground_truth: One dict per treatment with "treated_data" and
            "control_data" DataFrames, in the same order as ``treatment_list``.
        treatment_list: Treatment dicts, each with a "treatment_var" key.
            They are updated in place.
        outcome: Name of the outcome column.
        evidence: Mapping ``column -> value`` for the conditional
            distributions. ``None`` means no evidence.

    Returns:
        The same ``treatment_list``. Each dict gains the keys
        "Treated Interventional Distribution",
        "Control Interventional Distribution",
        "Treated Conditional Interventional Distribution" and
        "Control Conditional Interventional Distribution".

    Raises:
        ValueError: If a treatment variable appears in ``evidence``, or if there
            are fewer ground-truth entries than treatments.
    """
    if evidence is None:
        evidence = {}
    if len(ground_truth) < len(treatment_list):
        raise ValueError(
            f"Got {len(treatment_list)} treatments but only {len(ground_truth)} ground-truth entries."
        )

    for trt, gt_dict in zip(treatment_list, ground_truth):
        treatment_var = trt["treatment_var"]
        if treatment_var in evidence:
            raise ValueError(f"The treatment variable {treatment_var} cannot be in the evidence.")

        treated_data = gt_dict["treated_data"]
        control_data = gt_dict["control_data"]

        trt.update({
            "Treated Interventional Distribution": extract_empirical_distribution(
                treated_data, outcome
            ),
            "Control Interventional Distribution": extract_empirical_distribution(
                control_data, outcome
            ),
            "Treated Conditional Interventional Distribution": extract_empirical_distribution(
                treated_data, outcome, evidence
            ),
            "Control Conditional Interventional Distribution": extract_empirical_distribution(
                control_data, outcome, evidence
            ),
        })

    return treatment_list
